
import torch
import numpy as np


def data_gauss_clusters(n, d, center_dist = 1):
    """Sample two unit-variance Gaussian clusters placed symmetrically about 0.

    The positive cluster is centred at ``center_dist / sqrt(d) * ones(d)`` and
    the negative one at its reflection, so each centre lies ``center_dist``
    from the origin.

    Args:
        n: Total number of samples; ``n // 2`` go to the positive cluster and
            the remainder to the negative one.
        d: Dimensionality of each sample.
        center_dist: Euclidean distance of each cluster centre from the origin.

    Returns:
        ``(data, labels)``: a float32 tensor of shape ``(n, d)`` and an int64
        tensor of shape ``(n,)`` holding 1 for the positive cluster rows
        (listed first) and 0 for the negative cluster rows.
    """
    n_pos = n // 2
    n_neg = n - n_pos

    # Scale the all-ones direction so the centre has norm ``center_dist``.
    centre = center_dist * (1 / pow(d, 1 / 2)) * np.ones(d)

    # Positive cluster is drawn first so the RNG stream order is preserved.
    positives = np.random.normal(centre, size=(n_pos, d))
    negatives = np.random.normal(-1 * centre, size=(n_neg, d))

    stacked = np.concatenate([positives, negatives], axis=0)
    targets = np.concatenate([np.ones(n_pos), np.zeros(n_neg)], axis=0)

    features_t = torch.tensor(stacked, dtype=torch.float32)
    targets_t = torch.tensor(targets, dtype=torch.int64)
    return features_t, targets_t

def data_orthogonal(n, d, *, verbose=False):
    """Build a sparse +/-1 dataset with one non-zero feature per row.

    For rows ``i < n - 1`` the non-zero entry is placed at column
    ``(i % d) // 2``. Even rows get ``+1`` with label ``1`` and odd rows get
    ``-1`` with label ``0``. The last row is special-cased to ``+1`` in
    column 0 with label ``0``. For ``n >= 2`` this duplicates row 0's
    features with the opposite label.

    Args:
        n: Number of samples (must be >= 1).
        d: Number of features (must be >= 1).
        verbose: If True, print the generated numpy arrays. This was
            previously unconditional debug output.

    Returns:
        ``(X, y)``: a float32 tensor of shape ``(n, d)`` and an int64 tensor
        of shape ``(n,)``.

    Raises:
        ValueError: if ``n < 1`` or ``d < 1``.
    """
    if n < 1 or d < 1:
        raise ValueError(f"data_orthogonal needs n >= 1 and d >= 1; got n={n}, d={d}")

    X = np.zeros((n, d))
    y = np.zeros(n)

    for i in range(n - 1):
        # NOTE(review): (i % d) // 2 only ever reaches columns
        # 0..(d - 1) // 2, i.e. about half of the d features are used.
        # Confirm this is intended.
        col = (i % d) // 2
        if i % 2 == 0:
            X[i, col] = 1
            y[i] = 1
        else:
            X[i, col] = -1
            y[i] = 0

    # Deliberate special case for the final row (see docstring).
    X[n - 1, 0] = 1
    y[n - 1] = 0

    if verbose:
        print(X)
        print(y)

    return torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.int64)

def data_s_block(n, d, s):
    """Build a sparse +/-1 dataset whose first ``s`` rows share one feature.

    Rows ``0..s-1`` all place their non-zero entry in column 0. Every
    remaining row ``i`` uses column ``i // 2``, so rows ``2j`` and ``2j + 1``
    share a column. In every row, an even index gets ``+1`` with label ``1``
    and an odd index gets ``-1`` with label ``0``.

    Args:
        n: Number of samples.
        d: Number of features.
        s: Size of the leading block on column 0; must satisfy
            ``s <= min(n, d)``.

    Returns:
        ``(X, y)``: ``X`` is a float32 tensor of shape ``(n, d)`` and ``y`` is
        a float32 tensor of shape ``(n, 1)``.

    Raises:
        AssertionError: if ``s > min(n, d)``.
        ValueError: if rows past the block would need a column ``>= d``,
            i.e. ``n > s`` and ``(n - 1) // 2 >= d``.
    """
    assert s <= min(n, d)

    # The largest column used outside the block is (n - 1) // 2. Check it
    # up front instead of failing with a bare IndexError mid-loop.
    if n > s and (n - 1) // 2 >= d:
        raise ValueError(
            f"data_s_block needs d > (n - 1) // 2 = {(n - 1) // 2} for rows "
            f"beyond the block; got n={n}, d={d}, s={s}"
        )

    X = np.zeros((n, d))
    y = np.zeros(n)

    for i in range(n):
        col = 0 if i < s else i // 2
        if i % 2 == 0:
            X[i, col] = 1
            y[i] = 1
        else:
            # Odd rows are the negation of their even partner. With +1 here,
            # rows 2j and 2j+1 past the block would be identical points
            # carrying opposite labels.
            X[i, col] = -1
            y[i] = 0

    X = torch.tensor(X, dtype=torch.float32)
    y = torch.tensor(y, dtype=torch.float32)
    y = y.unsqueeze(1)

    return X, y

def sample_spherical(npoints, ndim=3):
    """Draw ``npoints`` random unit vectors in ``ndim`` dimensions.

    Each column of an ``(ndim, npoints)`` standard-normal draw is divided by
    its Euclidean norm. Normalising isotropic Gaussian samples yields points
    uniformly distributed on the unit sphere.

    Returns:
        numpy array of shape ``(npoints, ndim)``; each row has unit norm.
    """
    gaussian = np.random.randn(ndim, npoints)
    column_norms = np.linalg.norm(gaussian, axis=0)
    unit_columns = gaussian / column_norms
    return unit_columns.T

def data_wu(n, d, k):
    """Build a regression dataset whose targets are sums of ReLU features.

    Inputs are ``n`` random unit vectors in ``d`` dimensions. ``k`` random
    unit directions are drawn the same way. Each target is the sum over the
    directions of ``max(<x, v>, 0)``.

    Returns:
        ``(X, y)``: a float32 tensor of shape ``(n, d)`` and a float32 tensor
        of shape ``(n, 1)``.
    """
    inputs = sample_spherical(n, d)
    directions = sample_spherical(k, d)

    relu_responses = np.maximum(inputs @ directions.T, 0)
    targets = relu_responses.sum(axis=1)

    features_t = torch.tensor(inputs, dtype=torch.float32)
    targets_t = torch.tensor(targets, dtype=torch.float32).unsqueeze(1)
    return features_t, targets_t

def data_temp(n, d):
    """Return a fixed single-sample toy dataset; ``n`` and ``d`` are ignored.

    Returns:
        ``(X, y)`` with ``X == [[1., 2.]]`` (shape ``(1, 2)``) and
        ``y == [[1.]]`` (shape ``(1, 1)``), both float32.
    """
    row = torch.tensor([1, 2], dtype=torch.float32)
    features = row.reshape(1, 2)
    target = torch.tensor([1], dtype=torch.float32).reshape(1, 1)
    return features, target

class SynthData:
    """Wrap a synthetic-data generator's output in a ``TensorDataset``.

    ``gen_func`` is called as ``gen_func(n, d, **kwargs)`` and must return a
    ``(features, labels)`` pair of tensors with matching first dimension.
    """

    def __init__(self, n, d, gen_func, **kwargs):
        features, targets = gen_func(n, d, **kwargs)
        # The wrapped dataset is exposed as ``self.data``.
        self.data = torch.utils.data.TensorDataset(features, targets)

    def get_loader(self, batch_size):
        """Create a shuffling DataLoader, cache it on ``self.data_loader``, and return it."""
        loader = torch.utils.data.DataLoader(
            self.data, batch_size=batch_size, shuffle=True
        )
        self.data_loader = loader
        return loader

